"""
Visualization script for initial figures comparing gs_acc to q_acc.
"""

# %%
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns

# Typeset all figure text with LaTeX and use seaborn's white-grid theme.
matplotlib.rc("text", usetex=True)
sns.set_style(style="whitegrid")

# Load the simulation results, then discard the first column of the file.
data = pd.read_csv("results/final/sim_final.csv")
data = data.iloc[:, 1:]


# %% Restructure data to be long-form
# Wide input columns:
#   identifiers:  n_label, gs_acc, q_acc
#   measurements: <estimator>_<metric>, with estimator in {so, dsl}
#                 and metric in {bias, rmse, coverage}
# Long output columns:
#   n_label, gs_acc, q_acc, metric, estimator, value
id_columns = ["n_label", "gs_acc", "q_acc"]
measure_columns = [
    f"{est}_{met}" for met in ("bias", "rmse", "coverage") for est in ("so", "dsl")
]

# Stack the six measurement columns into (variable, value) pairs.
plot_data = data.melt(
    id_vars=id_columns,
    value_vars=measure_columns,
    var_name="variable",
    value_name="value",
)

# Every melted name reads "<estimator>_<metric>": remove the combined column
# and store its two halves as separate columns.
plot_data[["estimator", "metric"]] = plot_data.pop("variable").str.split(
    "_", expand=True
)

# Put the columns in the target order.
plot_data = plot_data[id_columns + ["metric", "estimator", "value"]]


# %% Visualize
## Get organizing values
q_acc_vals = sorted(data["q_acc"].unique())
gs_acc_vals = sorted(data["gs_acc"].unique())
n_label_vals = sorted(data["n_label"].unique())
metrics = ["bias", "rmse", "coverage"]
estimators = ["so", "dsl"]

fig, axes = plt.subplots(
    len(metrics), len(q_acc_vals), figsize=(15, 6), sharex=True, sharey="row"
)
fig.suptitle(
    "SO vs DSL Comparison for Varying Surrogate gs_accuracy and Error Distribution\n"
    "Using Won et al. Data",
    y=1.0,
)
for col, q_acc in enumerate(q_acc_vals):
    for row, metric in enumerate(metrics):
        ax = axes[row, col]
        if row == len(metrics) - 1:
            # ax.set_xlabel(f"$Pr[Q \\ne Y| Y=1]={q_acc}$")
            ax.set_xlabel(f"N Labelled\n${q_acc}$")
        if col == 0:
            ax.set_ylabel(metric.capitalize())
        for estimator in estimators:
            for gs_acc_val in gs_acc_vals:
                df = plot_data.loc[
                    plot_data["metric"].eq(metric)
                    & plot_data["q_acc"].eq(q_acc)
                    & plot_data["estimator"].eq(estimator)
                    & plot_data["gs_acc"].eq(gs_acc_val),
                    :,
                ]
                color_map = sns.color_palette(
                    {"so": "Blues", "dsl": "Reds"}[estimator], n_colors=len(gs_acc_vals)
                )
                ax.plot(
                    df["n_label"],
                    df["value"],
                    label=f"{estimator} (${gs_acc_val}$)",
                    color=color_map[gs_acc_vals.index(gs_acc_val)],
                    marker={"so": "o", "dsl": "x"}[estimator],
                )
            ax.set_xticks(n_label_vals)
            ax.set_xticklabels(n_label_vals, rotation=45, ha="right")
            match metric:
                case "bias":
                    ax.axhline(0, color="black", linestyle="--")
                case "rmse":
                    ax.set_yscale("log")
                case "coverage":
                    ax.set_ylim(0, 1)
                    ax.axhline(0.95, color="black", linestyle="--")
fig.supxlabel(r"Q Accuracy", y=-0.05)
ax.legend(bbox_to_anchor=(1, 2.5), ncol=1, title="Estimator (gs_accuracy)")


# %% Different version - separate figure for each metric
def _facet_metric(
    metric: str, ylabel: str, title: str, *, reference: float | None = None
) -> sns.FacetGrid:
    """Draw one metric as a grid of small line plots and return the grid.

    Facet rows are q_acc values, facet columns are gs_acc values, and each
    facet plots value against n_label with one line per estimator. Reads the
    module-level ``plot_data`` frame.

    Args:
        metric: Value of the ``metric`` column to plot.
        ylabel: Figure-level y-axis label.
        title: Figure title.
        reference: Optional y position of a dashed black line drawn on every
            facet.

    Returns:
        The populated ``FacetGrid``.
    """
    grid = sns.FacetGrid(
        plot_data.loc[plot_data["metric"].eq(metric), :],
        col="gs_acc",
        row="q_acc",
        hue="estimator",
        sharex=True,
        sharey="row",
        margin_titles=True,
        height=1,
        aspect=1,
    )
    grid.map(sns.lineplot, "n_label", "value", marker="o")
    if reference is not None:
        # Draw directly on each axes rather than through grid.map: map would
        # add the line once per hue level and tag it with the hue label, which
        # replaces the estimator handles used by the legend.
        for facet_ax in grid.axes.flat:
            facet_ax.axhline(reference, color="black", linestyle="--")
    grid.add_legend(title="Estimator")
    # Blank the per-facet labels; figure-level labels are used instead.
    grid.set_xlabels("")
    grid.set_ylabels("")
    grid.figure.supxlabel(r"N Labelled")
    grid.figure.supylabel(ylabel)
    grid.figure.suptitle(title, y=1.05)
    return grid


g = _facet_metric(
    "bias",
    r"Mean Abs Bias",
    "SO vs DSL Bias for Varying Gold-Standard and Surrogate Accuracy",
)

# %%
g = _facet_metric(
    "rmse",
    r"Mean RMSE",
    "SO vs DSL RMSE for Varying Gold-Standard and Surrogate Accuracy",
)


# %%
g = _facet_metric(
    "coverage",
    r"Coverage",
    "SO vs DSL Coverage for Varying Gold-Standard and Surrogate Accuracy",
    reference=0.95,
)